[common] Remove kvpacked and qkvpacked attention functions for every kernel type. #2287
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax
7 files reviewed, no comments
I think this is similar to #2272 :) Yes, Jax needs a bit of fixing in order to get its attention working.
Could you add the deprecation note for these qkvpacked/kvpacked APIs as we discussed offline please? Thanks.
Signed-off-by: Pawel Gadzinski <[email protected]>
Greptile Overview
Greptile Summary
Refactors fused attention code by eliminating duplicate kvpacked and qkvpacked kernel functions, moving pointer arithmetic into helper functions at the common API layer.
Key Changes
- Adds helper functions (`make_tensor_view`, `calculate_qkv_stride`, `calculate_kv_stride`) to handle tensor unpacking (a small byte-stride sketch follows after this list)
- Deprecated packed API functions now unpack QKV/KV tensors and call non-packed kernel implementations
- Removes ~1,400 lines of duplicate kernel code from the arbitrary_seqlen implementation
- Adds deprecation warnings to the packed API functions in the header file
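The reviews in this thread describe the unpacking stride as (bits × h_kv × d) / 8 bytes. As a rough, illustrative sketch only (the function name, signature, and layout handling below are assumptions, not the PR's actual helper):

```cpp
#include <cstddef>

// Hypothetical sketch of a KV byte-stride helper, assuming a 2HD-style packed
// layout (KV stored as [..., 2, h_kv, d]) as described in the review. The real
// calculate_kv_stride() in this PR may differ in name, arguments, and the set
// of layouts it handles.
std::size_t kv_stride_bytes(std::size_t dtype_bits, std::size_t h_kv, std::size_t d) {
  // Byte offset from the start of K to the start of V within the packed buffer.
  return (dtype_bits * h_kv * d) / 8;
}

// Example: bf16 (16 bits), h_kv = 8 KV heads, head dim d = 64
//   -> 16 * 8 * 64 / 8 = 1024 bytes between the K and V data.
```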
Issues Found
- Critical bug in KV-packed max512 stride calculation: lines 840 and 977 incorrectly use `h_q` instead of `h_kv` when calculating the byte offset between the K and V tensors. This will cause incorrect memory access in Grouped Query Attention (GQA) scenarios where h_q ≠ h_kv.
Confidence Score: 2/5
- This PR contains critical bugs in stride calculations that will cause memory corruption in GQA scenarios
- The refactoring approach is sound and removes significant code duplication, but contains a critical logic error in the KV-packed tensor stride calculations (lines 840, 977) where `h_q` is used instead of `h_kv`. This bug exists in both the forward and backward passes for the max512 kernels and will cause incorrect memory access when h_q ≠ h_kv (Grouped Query Attention). While the bug may not manifest in standard attention where h_q == h_kv, it represents a correctness issue that must be fixed before merge.
- transformer_engine/common/fused_attn/fused_attn.cpp requires immediate attention - fix the stride calculations on lines 840 and 977
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 2/5 | Refactors packed attention functions to use helper functions and tensor views. Contains critical bugs in KV-packed stride calculations (lines 840, 977) where h_q is used instead of h_kv. |
| transformer_engine/common/include/transformer_engine/fused_attn.h | 5/5 | Adds deprecation warnings to packed QKV/KV API functions, directing users to use separate Q, K, V tensors instead. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch
participant API as nvte_fused_attn_*_kvpacked (Deprecated)
participant Helper as Helper Functions
participant Kernel as fused_attn_max_512_* (Non-packed)
Note over User,Kernel: BEFORE: Separate functions for each kernel type
User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
API->>Kernel: fused_attn_max_512_fwd_kvpacked(KV_packed)
Note over Kernel: Unpacks KV internally<br/>stride = 2*h*d
Kernel-->>API: result
API-->>User: result
Note over User,Kernel: AFTER: Unified using helper functions
User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
API->>Helper: calculate_kv_stride(layout, dtype, h_kv, d)
Helper-->>API: stride (in bytes)
API->>Helper: make_tensor_view(KV, shape, offset=0)
Helper-->>API: K_view
API->>Helper: make_tensor_view(KV, shape, offset=stride)
Helper-->>API: V_view
API->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view)
Note over Kernel: Works with unpacked views
Kernel-->>API: result
API-->>User: result
2 files reviewed, 2 comments
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
// Unpack KV and call the non-packed function
size_t stride = 2 * h_q * d;  // For max512, KV layout is BS2HD or SB2HD
logic: stride calculation uses h_q but should use h_kv for KV-packed tensors
For KV-packed layout BS2HD, the memory contains K[h_kv, d] followed by V[h_kv, d]. The stride to skip from K to V should be based on h_kv, not h_q (which can differ in GQA scenarios).
Suggested change:
- size_t stride = 2 * h_q * d;  // For max512, KV layout is BS2HD or SB2HD
+ size_t stride = 2 * h_kv * d;  // For max512, KV layout is BS2HD or SB2HD
input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
// Unpack KV and dKV and call the non-packed function
size_t stride = 2 * h_q * d;
logic: stride calculation uses h_q but should use h_kv for KV-packed tensors (backward pass)
Same issue as in forward pass - the stride should be based on h_kv for KV-packed tensors.
Suggested change:
- size_t stride = 2 * h_q * d;
+ size_t stride = 2 * h_kv * d;
That's a good point - this is a leftover from my previous change. Weird that the tests didn't catch it. Fixed.
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR consolidates fused attention implementations by removing duplicate QKV-packed and KV-packed kernel variants across three backend types (max512, arbitrary seqlen, FP8). The refactoring moves tensor unpacking logic from kernel-level to the common API layer in fused_attn.cpp, reducing ~1,388 lines of duplicated code.
Key changes:
- Removes `*_qkvpacked` and `*_kvpacked` functions from all three kernel types
- Adds helper functions for stride calculation and tensor view creation
- Unpacks QKV/KV tensors in the common API layer before calling unified kernel functions
- Deprecates the packed QKV/KV APIs in the public header
Critical issue found:
- KV-packed max512 paths (lines 842, 979 in fused_attn.cpp) use `h_q` instead of `h_kv` for the stride calculation, causing incorrect pointer arithmetic in Grouped Query Attention (GQA) scenarios where h_q != h_kv
Confidence Score: 2/5
- Critical stride calculation bug will cause memory corruption in GQA with the max512 backend
- Good refactoring approach, but the incorrect stride calculation using `h_q` instead of `h_kv` in the max512 KV-packed paths will access wrong memory addresses when query and key/value have different head counts (GQA scenarios)
- transformer_engine/common/fused_attn/fused_attn.cpp lines 842 and 979 - incorrect stride calculation for KV-packed tensors
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 2/5 | Major refactoring that consolidates packed variants into unpacking logic. Critical bug: KV-packed max512 forward/backward use incorrect stride calculation (h_q instead of h_kv) for GQA scenarios. |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Removes duplicate QKV/KV-packed function implementations, keeping only the unpacked variants. Clean removal of ~530 lines of duplicate code. |
| transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu | 5/5 | Removes QKV/KV-packed implementations (~264 lines), retaining only unpacked function. Good code deduplication. |
| transformer_engine/common/fused_attn/fused_attn_fp8.cu | 5/5 | Removes FP8 packed variant implementations (~418 lines). Good deduplication, shifts unpacking to higher level. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch User
participant API as nvte_fused_attn_fwd_kvpacked<br/>(common API - fused_attn.cpp)
participant Helper as Helper Functions<br/>(calculate_kv_stride, etc.)
participant Kernel as Kernel Implementation<br/>(fused_attn_*_fwd)
User->>API: Call with packed KV tensor<br/>[b, s, 2, h_kv, d]
Note over API: Determine backend type<br/>(max512/arbitrary/fp8)
alt max512 backend
API->>Helper: Calculate stride for unpacking
Note over Helper: Should use h_kv, not h_q!<br/>(Bug at lines 842, 979)
Helper-->>API: stride = 2 * h_q * d (WRONG)
else arbitrary/fp8 backend
API->>Helper: calculate_kv_stride(h_kv, d)
Helper-->>API: stride = (bits * h_kv * d) / 8 (CORRECT)
end
API->>API: make_tensor_view(KV, shape, 0)<br/>→ K_view
API->>API: make_tensor_view(KV, shape, stride)<br/>→ V_view
API->>Kernel: fused_attn_*_fwd(Q, K_view, V_view)
Note over Kernel: Operates on unpacked<br/>Q, K, V tensors
Kernel-->>API: Output O
API-->>User: Return attention output
8 files reviewed, no comments
Signed-off-by: Pawel Gadzinski <[email protected]>
Greptile Overview
Greptile Summary
This PR successfully refactors fused attention functions by eliminating code duplication across three kernel types (arbitrary seqlen, max512, and fp8). The key changes include:
- New helper functions: Added `make_tensor_view()`, `calculate_qkv_stride()`, `calculate_kv_stride()`, and shape calculation helpers in an anonymous namespace
- Unpacking moved up: KV/QKV tensor unpacking now happens in the common API layer (`nvte_fused_attn_fwd_qkvpacked`, `nvte_fused_attn_fwd_kvpacked`) instead of within each kernel-specific function
- Kernel simplification: Removed the `_qkvpacked` and `_kvpacked` variants of kernel functions (`fused_attn_max_512_fwd_qkvpacked`, `fused_attn_max_512_fwd_kvpacked`, etc.), reducing code duplication
- Bug fix: Correctly uses `h_kv` instead of `h_q` for KV-packed stride calculations, fixing a bug in GQA (Grouped Query Attention) scenarios where the number of KV heads differs from Q heads
- APIs marked deprecated: Added deprecation notices to `nvte_fused_attn_fwd_qkvpacked` and `nvte_fused_attn_fwd_kvpacked`
The refactoring maintains functional equivalence while improving code maintainability. The stride calculations are now correct for all layout types (NVTE_HD_2HD, NVTE_HD_H2D, NVTE_3HD, NVTE_H3D).
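As a rough orientation on what "correct for all layout types" means, the sketch below shows the element offsets between adjacent packed tensors under the usual interpretation of these layout names. The enum and function here are purely illustrative stand-ins, not the PR's actual helpers (which work in bytes and live in fused_attn.cpp):

```cpp
#include <cstddef>

// Illustrative only: element offsets between adjacent packed tensors, assuming
// the conventional meaning of the layout names (e.g. 2HD = [..., 2, h, d],
// H2D = [..., h, 2, d]).
enum class PackedLayout { L_3HD, L_H3D, L_HD_2HD, L_HD_H2D };

std::size_t offset_to_next_tensor(PackedLayout layout, std::size_t h, std::size_t d) {
  switch (layout) {
    case PackedLayout::L_3HD:    return h * d;  // QKV as [..., 3, h, d]: Q->K and K->V jump by h*d
    case PackedLayout::L_H3D:    return d;      // QKV as [..., h, 3, d]: per head, jump by d
    case PackedLayout::L_HD_2HD: return h * d;  // KV as [..., 2, h_kv, d]: K->V jump by h_kv*d
    case PackedLayout::L_HD_H2D: return d;      // KV as [..., h_kv, 2, d]: per head, jump by d
  }
  return 0;
}
```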
Confidence Score: 5/5
- This PR is safe to merge with high confidence - it's a well-executed refactoring that improves code quality and fixes stride calculation bugs
- The refactoring is clean and well-structured with proper helper functions. The stride calculations correctly use `h_kv` for KV-packed tensors, fixing potential bugs in GQA scenarios. All unpacking logic has been centralized with consistent patterns across forward/backward passes and all kernel types. The changes maintain backward compatibility by keeping the deprecated API functions.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 5/5 | Major refactoring that removes code duplication by moving KV/QKV unpacking logic from kernel-specific functions to the common API layer. Correctly fixes stride calculation bugs for GQA scenarios where h_q != h_kv. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch API
participant CommonAPI as nvte_fused_attn_fwd_kvpacked
participant Helpers as Helper Functions
participant Kernel as fused_attn_max_512_fwd
User->>CommonAPI: Call with packed KV tensor
CommonAPI->>Helpers: calculate_kv_stride(layout_group, dtype, h_kv, d)
Helpers-->>CommonAPI: stride (bytes)
CommonAPI->>Helpers: calculate_kv_unpacked_shape(KV, layout, h_kv, d)
Helpers-->>CommonAPI: unpacked_shape
CommonAPI->>Helpers: make_tensor_view(KV, shape, 0)
Helpers-->>CommonAPI: K_view
CommonAPI->>Helpers: make_tensor_view(KV, shape, stride)
Helpers-->>CommonAPI: V_view
CommonAPI->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view, ...)
Kernel-->>CommonAPI: output
CommonAPI-->>User: return output
1 file reviewed, no comments
/te-ci jax
/*! \brief Compute dot product attention with packed QKV input.
 *
 * \warning This API is **deprecated**.
Would \deprecated be a better marking for this message in the document?
Could we add something like this to give user warnings at compile time as well?
[[deprecated("nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked() {
}
Done, renders correctly in the docs and compiles locally with warnings.
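For reference, a combined marking might look roughly like the following sketch (parameter list elided; the exact wording of the deprecation message and Doxygen comment in the merged header may differ):

```cpp
/*! \brief Compute dot product attention with packed QKV input.
 *
 *  \deprecated Use nvte_fused_attn_fwd() with separate Q, K and V tensors instead.
 */
[[deprecated(
    "nvte_fused_attn_fwd_qkvpacked() is deprecated. "
    "Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked(/* ... existing parameters ... */);
```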
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR successfully removes code duplication in fused attention kernels by eliminating separate implementations for QKV-packed and KV-packed tensor layouts. The refactoring moves pointer arithmetic logic from kernel-specific functions (max512, arbitrary_seqlen, fp8) into helper functions at the common API layer.
Key changes:
- Added helper functions (`calculate_qkv_stride`, `calculate_kv_stride`, `make_tensor_view`) to unpack QKV/KV tensors
- Deprecated packed API functions now create tensor views and delegate to unified non-packed implementations
- Removed ~1,388 lines of duplicate kernel code across the three attention backends
- Stride calculations correctly use `h_kv` for KV-packed tensors (fixed in commit ec7531d)
Impact:
- Reduces maintenance burden and potential for bugs from code duplication
- JAX users will continue to use deprecated packed APIs (with warnings) until migration
- PyTorch users unaffected (already use non-packed APIs)
Confidence Score: 5/5
- This PR is safe to merge with high confidence - it's a well-executed refactoring that eliminates code duplication without changing behavior
- Score reflects thorough code review showing correct stride calculations, proper pointer arithmetic using helper functions, and elimination of 1,388 lines of duplicate code. The previous stride calculation bug (h_q vs h_kv) was already fixed in commit ec7531d. Deprecation warnings properly guide users to non-packed APIs.
- No files require special attention - all implementations correctly use helper functions with appropriate parameters
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/include/transformer_engine/fused_attn.h | 5/5 | Added deprecation warnings to packed QKV/KV attention functions, guiding users to use separate tensor variants instead |
| transformer_engine/common/fused_attn/fused_attn.cpp | 5/5 | Added helper functions to unpack QKV/KV tensors and refactored deprecated packed APIs to call non-packed implementations, significantly reducing code duplication |
| transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu | 5/5 | Removed duplicate qkvpacked and kvpacked function implementations, now only retains the unified non-packed version |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Removed duplicate qkvpacked and kvpacked function implementations for arbitrary sequence length kernels |
| transformer_engine/common/fused_attn/fused_attn_fp8.cu | 5/5 | Removed duplicate qkvpacked and kvpacked function implementations for FP8 attention kernels |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch
participant API as nvte_fused_attn_fwd_qkvpacked<br/>(deprecated)
participant Helper as Helper Functions
participant Kernel as fused_attn_*_fwd<br/>(unified implementation)
User->>API: Call with packed QKV tensor
Note over API: Extract layout_group, h, d<br/>from packed tensor
API->>Helper: calculate_qkv_stride(layout_group, dtype, h, d)
Helper-->>API: Returns byte offset for K, V
API->>Helper: calculate_qkv_unpacked_shape(qkv, h, d)
Helper-->>API: Returns unpacked shape
API->>Helper: make_tensor_view(qkv, shape, 0)<br/>Create Q view
Helper-->>API: Q_view
API->>Helper: make_tensor_view(qkv, shape, stride)<br/>Create K view
Helper-->>API: K_view
API->>Helper: make_tensor_view(qkv, shape, 2*stride)<br/>Create V view
Helper-->>API: V_view
API->>Kernel: Call with separate Q_view, K_view, V_view
Note over Kernel: Process attention with<br/>separate Q, K, V tensors
Kernel-->>API: Output
API-->>User: Return result
1 file reviewed, no comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR successfully refactors the fused attention implementation by removing duplicate code for packed QKV/KV tensor formats. The refactoring moves pointer arithmetic from individual kernel implementations into the common C++ API layer.
Key Changes:
- Removes ~1,400 lines of duplicate code across 3 kernel types (max512, arbitrary seqlen, FP8)
- Adds helper functions to calculate strides and create tensor views for unpacking
- Marks packed variants as deprecated with `[[deprecated]]` attributes
- Maintains backward compatibility by keeping the deprecated API functions
Implementation:
The deprecated nvte_fused_attn_fwd_qkvpacked and nvte_fused_attn_fwd_kvpacked functions now unpack tensors using:
- `calculate_qkv_stride` / `calculate_kv_stride` - calculate byte offsets between Q/K/V in memory
- `make_tensor_view` - creates tensor views with pointer offsets (see the sketch after this list)
- Calls into existing non-packed kernel implementations
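To make the "tensor view" idea concrete, here is a minimal sketch using a hypothetical TensorView struct. The PR's actual make_tensor_view operates on NVTETensor objects, and stride/layout bookkeeping (the kernels are still told the packed layout) is omitted here:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor descriptor: raw pointer, shape, element size.
struct TensorView {
  void* data;
  std::vector<std::size_t> shape;
  std::size_t elem_bytes;
};

// Create a view into `packed` starting `offset_bytes` into its buffer, with a new
// (unpacked) shape. No data is copied; the view aliases the packed buffer.
TensorView make_view(const TensorView& packed, std::vector<std::size_t> unpacked_shape,
                     std::size_t offset_bytes) {
  return {static_cast<std::uint8_t*>(packed.data) + offset_bytes,
          std::move(unpacked_shape), packed.elem_bytes};
}

// Usage idea for a KV tensor packed as [b, s, 2, h_kv, d]:
//   K_view = make_view(KV, {b, s, h_kv, d}, /*offset_bytes=*/0);
//   V_view = make_view(KV, {b, s, h_kv, d}, /*offset_bytes=*/KV.elem_bytes * h_kv * d);
// Both views are then passed to the non-packed attention implementation.
```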
Testing:
JAX uses the packed variants, so JAX CI validates this refactoring. PyTorch uses only non-packed functions.
Confidence Score: 4/5
- This PR is safe to merge with low risk - it's a pure refactoring that eliminates code duplication
- Score reflects a well-executed refactoring with proper deprecation strategy. The implementation correctly handles stride calculations and tensor views. Minor deduction due to the removal of 1400+ lines that were functionally tested individually, though the refactored code maintains equivalent behavior. JAX CI provides good test coverage for the packed variants.
- No files require special attention - all changes are straightforward refactoring
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 4/5 | Refactors packed QKV/KV functions to reuse unpacked implementations via tensor views; adds helper functions for stride calculation and shape unpacking |
| transformer_engine/common/include/transformer_engine/fused_attn.h | 5/5 | Marks packed QKV/KV functions as deprecated with [[deprecated]] attributes and documentation updates |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Removes implementations of _qkvpacked and _kvpacked function variants (530 lines removed) |
| transformer_engine/common/fused_attn/fused_attn_fp8.cu | 5/5 | Removes implementations of _qkvpacked and _kvpacked FP8 function variants (418 lines removed) |
Sequence Diagram
sequenceDiagram
participant Caller as JAX/User Code
participant API as nvte_fused_attn_fwd_kvpacked (deprecated)
participant Helper as Helper Functions
participant Core as fused_attn_max_512_fwd / arbitrary / fp8
Caller->>API: Call with packed KV tensor
Note over API: KV shape: [B,S,2,H,D]
API->>Helper: calculate_kv_stride(layout, dtype, h_kv, d)
Helper-->>API: stride in bytes
API->>Helper: calculate_kv_unpacked_shape(KV, layout, format, ...)
Helper-->>API: unpacked shape [B,S,H,D]
API->>Helper: make_tensor_view(KV, shape, offset=0)
Helper-->>API: K_view (points to K data)
API->>Helper: make_tensor_view(KV, shape, offset=stride)
Helper-->>API: V_view (points to V data)
API->>Core: Call with separate K_view, V_view tensors
Core-->>API: Attention output
API-->>Caller: Return output
8 files reviewed, no comments
…kernel type. (NVIDIA#2287) * code drop Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * depracted compile time warning + \warning -> \deprecated Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Charlene Yang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
Description
There are 3 variants of the fused attention functions: separate QKV, KV-packed, and QKV-packed, which differ only in the pointer arithmetic for Q/K/V. This results in code duplication for each type of fused attention kernel: arbitrary seqlen, max 512, and FP8. This PR deduplicates the code and moves the pointer computation up one abstraction layer - from functions like `fused_attn_max_512_fwd_qkvpacked` into functions like `nvte_fused_attn_fwd_qkvpacked` in the common C++ API. These packed versions of the common attention API functions are used by JAX, so running the JAX CI is a good test of these changes. PyTorch uses only the non-packed functions.
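A heavily simplified sketch of the pattern the description refers to; the names, struct, and signatures below are hypothetical stand-ins (the real nvte_* functions take many more arguments and operate on NVTETensor handles):

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical, stripped-down tensor descriptor.
struct Tensor {
  void* data;
  std::size_t h, d, elem_bytes;
};

// Unified non-packed implementation (stand-in for nvte_fused_attn_fwd).
void attn_fwd(const Tensor& Q, const Tensor& K, const Tensor& V) {
  // ... kernel dispatch elided ...
  (void)Q; (void)K; (void)V;
}

// Deprecated KV-packed wrapper (stand-in for nvte_fused_attn_fwd_kvpacked):
// instead of calling a dedicated packed kernel, it offsets pointers into the
// packed KV buffer and delegates to the unified implementation.
void attn_fwd_kvpacked(const Tensor& Q, const Tensor& KV) {
  // Byte offset from K to V, assuming a 2HD-style packing [..., 2, h_kv, d].
  const std::size_t k_to_v_bytes = KV.elem_bytes * KV.h * KV.d;
  Tensor K = KV;  // view over the K half (no copy)
  Tensor V = KV;  // view over the V half (no copy)
  V.data = static_cast<std::uint8_t*>(KV.data) + k_to_v_bytes;
  attn_fwd(Q, K, V);
}
```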
Type of change
Checklist: